import torch
import torch.nn as nn
import torch.nn.functional as F


class AveragedSampleMarginLoss(nn.Module):
    """Linear margin-style loss computed directly on raw logits.

    For each sample the loss is ``-z_y + alpha * sum_{k != y} z_k``, where
    ``z_y`` is the logit at the labelled class. The per-sample values are
    then averaged into a single scalar.

    Note that the non-target logits are summed, not divided by the number
    of classes, so ``alpha`` alone sets their overall weight.
    """

    def __init__(self, alpha=0.1):
        """Store ``alpha``, the weight applied to the non-target logit sum."""
        super().__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        """Return the mean loss over all samples.

        Args:
            logits: Score tensor; the size of axis 1 is taken as the number
                of classes and the reduction runs over the last axis
                (assumes a ``(batch, num_classes)`` layout -- TODO confirm
                against callers).
            labels: Integer class indices in a form accepted by
                ``F.one_hot``.

        Returns:
            Scalar tensor holding the mean of the per-sample losses.
        """
        num_classes = logits.shape[1]

        # 1.0 at the labelled class, 0.0 elsewhere; its complement selects
        # every other class.
        target_mask = F.one_hot(labels, num_classes).float().to(logits.device)
        other_mask = 1 - target_mask

        target_logit = (logits * target_mask).sum(dim=-1)
        other_logit_sum = (logits * other_mask).sum(dim=-1)

        per_sample = self.alpha * other_logit_sum - target_logit
        return per_sample.mean()